library(caret) 
## Loading required package: ggplot2
## Loading required package: lattice
library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.1.1 ──
## ✔ broom        1.0.5     ✔ rsample      1.2.0
## ✔ dials        1.2.0     ✔ tibble       3.2.1
## ✔ dplyr        1.1.4     ✔ tidyr        1.3.1
## ✔ infer        1.0.5     ✔ tune         1.1.2
## ✔ modeldata    1.2.0     ✔ workflows    1.1.3
## ✔ parsnip      1.1.1     ✔ workflowsets 1.0.1
## ✔ purrr        1.0.2     ✔ yardstick    1.2.0
## ✔ recipes      1.0.8
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ purrr::discard()         masks scales::discard()
## ✖ dplyr::filter()          masks stats::filter()
## ✖ dplyr::lag()             masks stats::lag()
## ✖ purrr::lift()            masks caret::lift()
## ✖ yardstick::precision()   masks caret::precision()
## ✖ yardstick::recall()      masks caret::recall()
## ✖ yardstick::sensitivity() masks caret::sensitivity()
## ✖ yardstick::specificity() masks caret::specificity()
## ✖ recipes::step()          masks stats::step()
## • Search for functions across packages at https://www.tidymodels.org/find/
library(splines)
library(mgcv)
## Loading required package: nlme
## 
## Attaching package: 'nlme'
## The following object is masked from 'package:dplyr':
## 
##     collapse
## This is mgcv 1.9-0. For overview type 'help("mgcv-package")'.
library(pdp)
## 
## Attaching package: 'pdp'
## The following object is masked from 'package:purrr':
## 
##     partial
library(earth)
## Loading required package: Formula
## Loading required package: plotmo
## Loading required package: plotrix
## 
## Attaching package: 'plotrix'
## The following object is masked from 'package:scales':
## 
##     rescale
## Loading required package: TeachingDemos
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ forcats   1.0.0     ✔ readr     2.1.4
## ✔ lubridate 1.9.2     ✔ stringr   1.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ readr::col_factor() masks scales::col_factor()
## ✖ nlme::collapse()    masks dplyr::collapse()
## ✖ purrr::discard()    masks scales::discard()
## ✖ dplyr::filter()     masks stats::filter()
## ✖ stringr::fixed()    masks recipes::fixed()
## ✖ dplyr::lag()        masks stats::lag()
## ✖ purrr::lift()       masks caret::lift()
## ✖ pdp::partial()      masks purrr::partial()
## ✖ readr::spec()       masks yardstick::spec()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(ggplot2)
library(ISLR)
library(pls)
## 
## Attaching package: 'pls'
## 
## The following object is masked from 'package:caret':
## 
##     R2
## 
## The following object is masked from 'package:stats':
## 
##     loadings
library(corrplot)
## corrplot 0.92 loaded
## 
## Attaching package: 'corrplot'
## 
## The following object is masked from 'package:pls':
## 
##     corrplot
# Load the COVID recovery data (provides `dat`) and encode the
# categorical columns as factors before any recoding/modeling.
load("recovery.RData")

factor_vars <- c("gender", "race", "smoking", "hypertension",
                 "diabetes", "vaccine", "severity", "study")
for (v in factor_vars) {
  dat[[v]] <- as.factor(dat[[v]])
}

# Recode the categorical predictors to descriptive labels and set the
# reference levels used in modeling. Bare column names refer to the piped
# data frame (the original used `dat$...`, which reads from the pre-pipe
# object and silently breaks if upstream steps change the rows).
dat <- dat %>%
  select(-id) %>%
  mutate(
    # factor == number compares by level label here; values are 0/1/2/3
    gender = case_when(
      gender == 1 ~ "Male",
      TRUE ~ "Female"),
    race = case_when(
      race == 1 ~ "White",
      race == 2 ~ "Asian",
      race == 3 ~ "Black",
      TRUE ~ "Hispanic"),
    smoking = case_when(
      smoking == 0 ~ "Never Smoked",
      smoking == 1 ~ "Former Smoker",
      TRUE ~ "Current Smoker")) %>%
  # across(where(...)) replaces the superseded mutate_if()
  mutate(across(where(is.character), as.factor)) %>%
  mutate(race = relevel(race, ref = "White"),
         smoking = relevel(smoking, ref = "Never Smoked"),
         study = relevel(study, ref = "B"))

# Continuous variables: scatter plot of each against recovery time.
continuous_vars <- c("age", "height", "weight", "bmi", "SBP", "LDL")

for (var in continuous_vars) {
  # .data[[var]] is the tidy-eval replacement for the deprecated
  # aes_string() (source of the lifecycle warning in the original run).
  plot <- ggplot(dat, aes(x = .data[[var]], y = recovery_time)) +
    geom_point() +
    labs(x = var, y = "Time to Recovery (days)",
         title = paste(var, "vs. Time to Recovery")) +
    theme_classic()
  print(plot)
}

# Discrete variables: boxplot of recovery time within each level.
discrete_vars <- c("gender", "race", "smoking", "hypertension",
                   "diabetes", "vaccine", "severity", "study")

for (var in discrete_vars) {
  # .data[[var]] replaces the deprecated aes_string() interface.
  plot <- ggplot(dat, aes(x = .data[[var]], y = recovery_time)) +
    geom_boxplot() +
    labs(x = var, y = "Time to Recovery (days)",
         title = paste(var, "vs. Time to Recovery")) +
    theme_classic()
  print(plot)
}

# Correlation among the continuous predictors (screens for collinearity;
# note bmi is derived from height and weight).
dat_continuous <- dat %>%
  select(all_of(c("age", "height", "weight", "bmi", "SBP", "LDL")))

corrplot(cor(dat_continuous), method = 'number', type = 'lower')

# Variables to include in subset: height, weight, vaccine, severity, study
dat_subset <- dat %>%
  select(all_of(c("height", "weight", "vaccine", "severity", "study",
                  "recovery_time")))

# Smoothed trend of recovery time over age.
ggplot(dat, aes(x = age, y = recovery_time)) +
  geom_jitter() +
  geom_smooth() +
  theme_classic()
## `geom_smooth()` using method = 'gam' and formula = 'y ~ s(x, bs = "cs")'

# Split the data into training and testing sets

# 80/20 train/test split on the full data.
set.seed(1)
trainIndex <- createDataPartition(dat$recovery_time, p = 0.8, list = FALSE)
training_data <- dat[trainIndex, ]
testing_data <- dat[-trainIndex, ]

# Same split for the reduced-predictor data. The subset split now uses
# its own index: the original computed `trainIndex_sub` but then indexed
# with `trainIndex`, which only coincided because the seed was reset.
set.seed(1)
trainIndex_sub <- createDataPartition(dat_subset$recovery_time, p = 0.8, list = FALSE)
training_data_sub <- dat_subset[trainIndex_sub, ]
testing_data_sub <- dat_subset[-trainIndex_sub, ]


# Resampling scheme: 10-fold CV repeated 5 times. Two tuning-selection
# rules: the one-standard-error rule and lowest resampled RMSE.
ctrl_SE <- trainControl(method = "repeatedcv", number = 10, repeats = 5,
                        selectionFunction = "oneSE")

ctrl_best <- trainControl(method = "repeatedcv", number = 10, repeats = 5,
                          selectionFunction = "best")

# Dummy-coded design matrices (intercept column dropped) and responses.
x <- model.matrix(recovery_time ~ ., training_data)[, -1]
y <- training_data$recovery_time

x_sub <- model.matrix(recovery_time ~ ., training_data_sub)[, -1]
y_sub <- training_data_sub$recovery_time

# Train the LASSO model

set.seed(1)

# LASSO (glmnet, alpha = 1) over a log-spaced lambda grid, selected by
# the one-SE rule. The stray `data = training_data` argument was removed:
# caret's x/y interface does not take `data`. `T` spelled out as TRUE.
lasso_model <- train(
  x = x,
  y = y,
  method = "glmnet",
  trControl = ctrl_SE,
  tuneGrid = expand.grid(alpha = 1,
                         lambda = exp(seq(-6, -1, length = 100))),
  standardize = TRUE
)

plot(lasso_model, xTrans = log)

best_lambda <- lasso_model$bestTune$lambda
coef(lasso_model$finalModel, lasso_model$bestTune$lambda)
## 18 x 1 sparse Matrix of class "dgCMatrix"
##                                  s1
## (Intercept)           -840.04385489
## age                      0.16823163
## genderMale              -2.59075696
## raceAsian                3.50404233
## raceBlack                .         
## raceHispanic            -0.38081627
## smokingCurrent Smoker    2.97646686
## smokingFormer Smoker     1.72027649
## height                   4.86538187
## weight                  -5.41777863
## bmi                     17.43538534
## hypertension1            1.69648547
## diabetes1               -1.46759957
## SBP                      0.02233705
## LDL                     -0.02000477
## vaccine1                -6.58454252
## severity1                6.81026029
## studyA                  -4.75766922
# LASSO on the reduced predictor subset. Stray `data =` argument removed
# (not part of caret's x/y interface); TRUE spelled out.
set.seed(1)

lasso_fit_sub <- train(
  x = x_sub,
  y = y_sub,
  method = "glmnet",
  trControl = ctrl_SE,
  tuneGrid = expand.grid(alpha = 1,
                         lambda = exp(seq(-6, 2, length = 100))),
  standardize = TRUE
)
## Warning in nominalTrainWorkflow(x = x, y = y, wts = weights, info = trainInfo,
## : There were missing values in resampled performance measures.
plot(lasso_fit_sub, xTrans = log)

best_lambda <- lasso_fit_sub$bestTune$lambda
coef(lasso_fit_sub$finalModel, lasso_fit_sub$bestTune$lambda)
## 6 x 1 sparse Matrix of class "dgCMatrix"
##                     s1
## (Intercept) 88.9382606
## height      -0.3782789
## weight       0.2434927
## vaccine1    -2.5641618
## severity1    0.7450711
## studyA      -0.5830858

# Train the ridge regression model

set.seed(1)

# Ridge regression (glmnet, alpha = 0) over a wide lambda grid, picking
# the lowest resampled RMSE. Stray `data =` argument removed, TRUE
# spelled out, and `=` assignment replaced with `<-`.
ridge_model <- train(
  x = x,
  y = y,
  method = "glmnet",
  tuneGrid = expand.grid(alpha = 0,
                         lambda = exp(seq(-6, 10, length = 100))),
  trControl = ctrl_best,
  standardize = TRUE
)
## Warning in nominalTrainWorkflow(x = x, y = y, wts = weights, info = trainInfo,
## : There were missing values in resampled performance measures.
plot(ridge_model, xTrans = log)

best_lambda_ridge <- ridge_model$bestTune$lambda

# coefficients in the final model
coef(ridge_model$finalModel, s = ridge_model$bestTune$lambda)
## 18 x 1 sparse Matrix of class "dgCMatrix"
##                                  s1
## (Intercept)           -101.92788195
## age                      0.19639986
## genderMale              -2.71878900
## raceAsian                4.01116555
## raceBlack               -0.15248392
## raceHispanic            -0.98352878
## smokingCurrent Smoker    3.36685579
## smokingFormer Smoker     1.91249121
## height                   0.49957702
## weight                  -0.78408717
## bmi                      4.11762834
## hypertension1            1.68868623
## diabetes1               -1.84842122
## SBP                      0.04152787
## LDL                     -0.02903930
## vaccine1                -6.71601476
## severity1                6.91370346
## studyA                  -4.97385571
# Ridge on the reduced predictor subset. Stray `data =` argument removed
# (not part of caret's x/y interface); TRUE spelled out.
set.seed(1)

ridge_fit_sub <- train(
  x = x_sub,
  y = y_sub,
  method = "glmnet",
  tuneGrid = expand.grid(alpha = 0,
                         lambda = exp(seq(-6, 10, length = 100))),
  trControl = ctrl_best,
  standardize = TRUE
)
## Warning in nominalTrainWorkflow(x = x, y = y, wts = weights, info = trainInfo,
## : There were missing values in resampled performance measures.
plot(ridge_fit_sub, xTrans = log)

ridge_fit_sub$bestTune
##    alpha   lambda
## 41     0 1.591451
coef(ridge_fit_sub$finalModel, s = ridge_fit_sub$bestTune$lambda)
## 6 x 1 sparse Matrix of class "dgCMatrix"
##                      s1
## (Intercept) 133.1115977
## height       -0.7673816
## weight        0.5733665
## vaccine1     -6.4823273
## severity1     7.1377216
## studyA       -4.8237001

# Train the elastic net model

set.seed(1)

# Elastic net: jointly tune the mixing parameter alpha (0 = ridge,
# 1 = LASSO) and lambda. Stray `data =` argument removed; TRUE spelled out.
enet_model <- train(
  x = x,
  y = y,
  method = "glmnet",
  tuneGrid = expand.grid(alpha = seq(0, 1, length = 21),
                         lambda = exp(seq(-5, 5, length = 100))),
  trControl = ctrl_best,
  standardize = TRUE
)
## Warning in nominalTrainWorkflow(x = x, y = y, wts = weights, info = trainInfo,
## : There were missing values in resampled performance measures.
enet_model$bestTune
##      alpha      lambda
## 1501  0.75 0.006737947

# One color per alpha value so the tuning curves are distinguishable.
myCol <- rainbow(25)
myPar <- list(superpose.symbol = list(col = myCol),
              superpose.line = list(col = myCol))

plot(enet_model, par.settings = myPar)

# coefficients in the final model
coef(enet_model$finalModel, enet_model$bestTune$lambda)
## 18 x 1 sparse Matrix of class "dgCMatrix"
##                                  s1
## (Intercept)           -2.002609e+03
## age                    1.790879e-01
## genderMale            -2.947692e+00
## raceAsian              3.768215e+00
## raceBlack             -4.163253e-01
## raceHispanic          -7.390642e-01
## smokingCurrent Smoker  3.556445e+00
## smokingFormer Smoker   2.179138e+00
## height                 1.172685e+01
## weight                -1.267781e+01
## bmi                    3.825160e+01
## hypertension1          1.903922e+00
## diabetes1             -1.662082e+00
## SBP                    2.140266e-02
## LDL                   -2.939422e-02
## vaccine1              -6.793896e+00
## severity1              7.273653e+00
## studyA                -4.878764e+00
# Elastic net on the reduced predictor subset. Stray `data =` argument
# removed (not part of caret's x/y interface); TRUE spelled out.
set.seed(1)

enet_fit_sub <- train(
  x = x_sub,
  y = y_sub,
  method = "glmnet",
  tuneGrid = expand.grid(alpha = seq(0, 1, length = 21),
                         lambda = exp(seq(-5, 5, length = 100))),
  trControl = ctrl_best,
  standardize = TRUE
)
## Warning in nominalTrainWorkflow(x = x, y = y, wts = weights, info = trainInfo,
## : There were missing values in resampled performance measures.
enet_fit_sub$bestTune
##    alpha   lambda
## 55     0 1.575457
myCol <- rainbow(25)
myPar <- list(superpose.symbol = list(col = myCol),
              superpose.line = list(col = myCol))
plot(enet_fit_sub, par.settings = myPar)

coef(enet_fit_sub$finalModel, enet_fit_sub$bestTune$lambda)
## 6 x 1 sparse Matrix of class "dgCMatrix"
##                      s1
## (Intercept) 133.1787123
## height       -0.7679694
## weight        0.5738241
## vaccine1     -6.4862640
## severity1     7.1420582
## studyA       -4.8267005

# Train the partial least squares model

set.seed(1)

# PLS on the full predictor set: tune the number of components (1-15),
# centering and scaling predictors first.
pls_model <- train(
  x = x,
  y = y,
  method = "pls",
  tuneGrid = data.frame(ncomp = 1:15),
  trControl = ctrl_best,
  preProcess = c("center", "scale")
)

ggplot(pls_model, highlight = TRUE) + theme_classic()

# PLS on the reduced predictor subset (at most 6 components available).
set.seed(1)

pls_fit_sub <- train(
  x = x_sub,
  y = y_sub,
  method = "pls",
  tuneGrid = data.frame(ncomp = 1:6),
  trControl = ctrl_best,
  preProcess = c("center", "scale")
)

ggplot(pls_fit_sub, highlight = TRUE) + theme_classic()

# Train the MARS model

set.seed(1)

# MARS: tune the interaction degree (1-3) and the number of retained
# terms after pruning (2-15); select by lowest resampled RMSE.
mars_grid <- expand.grid(degree = 1:3,
                         nprune = 2:15)

mars_model <- train(
  x = x,
  y = y,
  method = "earth",
  tuneGrid = mars_grid,
  metric = "RMSE",
  trControl = ctrl_best
)

ggplot(mars_model) + theme_classic()

mars_model$bestTune
##    nprune degree
## 16      3      2
coef(mars_model$finalModel)
##          (Intercept)          h(bmi-30.7) h(bmi-30.7) * studyA 
##             38.85227             27.39024            -20.91526

# MARS on the reduced predictor subset (same tuning grid).
set.seed(1)

mars_fit_sub <- train(
  x = x_sub,
  y = y_sub,
  method = "earth",
  tuneGrid = mars_grid,
  metric = "RMSE",
  trControl = ctrl_best
)

ggplot(mars_fit_sub) + theme_classic()

mars_fit_sub$bestTune
##    nprune degree
## 41     14      3
coef(mars_fit_sub$finalModel)
##                               (Intercept) 
##                               36.54169610 
##                           h(height-159.6) 
##                               -0.45993611 
##                           h(159.6-height) 
##                               10.10337216 
##          h(159.6-height) * h(weight-81.3) 
##                                6.27108983 
##          h(159.6-height) * h(81.3-weight) 
##                               -0.21300902 
## h(159.6-height) * h(weight-81.3) * studyA 
##                               -5.89235302 
##                            h(weight-77.8) 
##                                1.08027270 
##          h(171.6-height) * h(weight-77.8) 
##                                0.53281219 
##                                  vaccine1 
##                               -6.58139206 
## h(171.6-height) * h(weight-77.8) * studyA 
##                               -0.40470585 
##                                 severity1 
##                                7.48014816 
##          h(height-159.6) * h(87.3-weight) 
##                                0.07525588 
##                  h(159.6-height) * studyA 
##                               -6.23583295
# Compare cross-validated performance across all ten fitted models.
model_list <- list(lasso = lasso_model,
                   enet = enet_model,
                   ridge = ridge_model,
                   pls = pls_model,
                   mars = mars_model,
                   lasso_sub = lasso_fit_sub,
                   enet_sub = enet_fit_sub,
                   ridge_sub = ridge_fit_sub,
                   pls_sub = pls_fit_sub,
                   mars_sub = mars_fit_sub)
resamples <- resamples(model_list)
bwplot(resamples, metric = "RMSE")

# Prepare the test data for predictions
x_test <- model.matrix(recovery_time ~ ., testing_data)[, -1]
x_test_sub <- model.matrix(recovery_time ~ ., testing_data_sub)[, -1]

# Helper: test-set RMSE (first element of postResample) for one model.
test_rmse <- function(fit, newx, truth) {
  postResample(predict(fit, newdata = newx), truth)[1]
}

# Collect test RMSE for every model; subset models predict from the
# reduced design matrix against the subset test response.
test_RMSE <- tibble(
  Model = c("LASSO", "Elastic Net", "Ridge", "PLS", "MARS",
            "LASSO (Subset)", "Elastic Net (Subset)", "Ridge (Subset)",
            "PLS (Subset)", "MARS (Subset)"),
  RMSE = c(
    test_rmse(lasso_model, x_test, testing_data$recovery_time),
    test_rmse(enet_model, x_test, testing_data$recovery_time),
    test_rmse(ridge_model, x_test, testing_data$recovery_time),
    test_rmse(pls_model, x_test, testing_data$recovery_time),
    test_rmse(mars_model, x_test, testing_data$recovery_time),
    test_rmse(lasso_fit_sub, x_test_sub, testing_data_sub$recovery_time),
    test_rmse(enet_fit_sub, x_test_sub, testing_data_sub$recovery_time),
    test_rmse(ridge_fit_sub, x_test_sub, testing_data_sub$recovery_time),
    test_rmse(pls_fit_sub, x_test_sub, testing_data_sub$recovery_time),
    test_rmse(mars_fit_sub, x_test_sub, testing_data_sub$recovery_time)
  )
)

test_RMSE %>% arrange(RMSE)
## # A tibble: 10 × 2
##    Model                 RMSE
##    <chr>                <dbl>
##  1 MARS (Subset)         17.8
##  2 MARS                  17.9
##  3 PLS                   18.4
##  4 Elastic Net           18.4
##  5 LASSO                 19.1
##  6 Ridge                 19.8
##  7 Ridge (Subset)        20.4
##  8 Elastic Net (Subset)  20.4
##  9 PLS (Subset)          20.4
## 10 LASSO (Subset)        20.8